import functools
import gc

import backoff
import cachetools
import numpy as np
import openai
from beartype import beartype

from .caching import delegating_cached, shelve_cached
from .token_logprobs import ABSENT_TOKEN_LOGPROB, TokenLogprobs, TokenSet

# In-memory LRU caches handed (lazily, via lambdas) to the `delegating_cached`
# decorators of the `*_cached` functions and `next_token_logprobs` in this module.
# Cache for `next_token_logprobs_openai_cached` results.
CACHE_next_token_logprobs_openai_cached = cachetools.LRUCache(maxsize=1000)  # Should also be on-disk
# Cache for `next_token_logprobs_hf_cached` results.
CACHE_next_token_logprobs_hf_cached = cachetools.LRUCache(maxsize=1000)
# Cache for the final `TokenLogprobs` objects built by `next_token_logprobs`.
CACHE_next_token_logprobs = cachetools.LRUCache(maxsize=16)  # Small on purpose

"""
Caching points:

- next_token_logprobs_openai(prompt, model) -> few tokens (~1kB)
  - small, cost-efficient +++
- next_token_logprobs_hf(prompt, model) -> many logits
  - huge (60kB+), not worth saving ..., good to have in-mem cache (~10k)
- next_token_logprobs(prompt, model, top_n=None, min_logprob=None) -> TokenLogprobs
  - just a small in-mem cache to prevent quick re-computation of TokenLogprobs
  


"""


class OpenAIMissingLogprobsError(Exception):
    """Signals that an OpenAI API response did not contain the expected logprobs."""


def is_openai_model(model_name_or_path: str) -> bool:
    """Tell whether a model name refers to an OpenAI API model.

    Names starting with ``"gpt2"`` or containing ``"/"`` are treated as
    non-OpenAI models and return ``False`` without any API call. Every other
    name must appear in ``openai_models()``.

    Args:
        model_name_or_path: Model name or path to classify.

    Returns:
        ``False`` for ``gpt2*`` / ``org/name`` style names, ``True`` for a name
        listed by ``openai_models()``.

    Raises:
        AssertionError: If the name is neither of the above.
    """
    if model_name_or_path.startswith("gpt2") or "/" in model_name_or_path:
        return False
    # Raised explicitly (rather than via `assert`) so the check is not stripped
    # when Python runs with -O; the exception type is kept for existing callers.
    if model_name_or_path not in openai_models():
        raise AssertionError(f"Unknown OpenAI model: {model_name_or_path!r}")
    return True


@functools.lru_cache
def openai_models():
    """Fetch the ids of the models listed by the OpenAI API (needs an API key).

    The returned list is memoized by ``functools.lru_cache``, so repeated calls
    reuse the first successful result.
    """
    import openai

    listing = openai.Model.list()
    model_ids = []
    for entry in listing["data"]:
        model_ids.append(entry["id"])
    return model_ids


@backoff.on_exception(backoff.expo, (OpenAIMissingLogprobsError, openai.APIError), max_time=180)
def next_token_logprobs_openai(prompt: str, model: str) -> tuple[list[str], np.ndarray]:
    """Return the top next-token continuations of a prompt from an OpenAI model.

    The call is retried with exponential backoff for up to 180 seconds when it
    raises ``openai.APIError`` or when the response carries no usable logprobs.

    Args:
        prompt: Text to continue.
        model: OpenAI model id; must be listed by ``openai_models()``.

    Returns:
        ``(tokens, logprobs)`` where ``tokens`` is ordered by decreasing
        log-probability and ``logprobs`` is the aligned ``float32`` array.

    Raises:
        AssertionError: If ``model`` is not listed by ``openai_models()``
            (not retried).
        OpenAIMissingLogprobsError: If the response still lacks logprobs once
            the retry budget is exhausted.
    """
    import openai

    # Raised explicitly (rather than via `assert`) so the check is not stripped
    # when Python runs with -O; the exception type is kept for existing callers.
    if model not in openai_models():
        raise AssertionError(f"Unknown OpenAI model: {model!r}")
    # NOTE(review): only the first generated position's `top_logprobs` is read
    # below, so `max_tokens=5` may be more than needed — confirm before changing.
    r = openai.Completion.create(model=model, prompt=prompt, max_tokens=5, logprobs=100)
    try:
        tok_lps = r["choices"][0]["logprobs"]["top_logprobs"][0]
    except (IndexError, KeyError, TypeError) as err:
        # A null or absent field surfaces as TypeError/KeyError rather than
        # IndexError; treat all of them as "missing logprobs" so they are retried.
        raise OpenAIMissingLogprobsError("OpenAI API response missing logprobs") from err
    ranked = sorted(tok_lps.items(), key=lambda x: x[1], reverse=True)
    tokens = [tok for tok, _ in ranked]
    logprobs = np.array([lp for _, lp in ranked], dtype=np.float32)
    return tokens, logprobs


@delegating_cached(lambda: CACHE_next_token_logprobs_openai_cached)
def next_token_logprobs_openai_cached(prompt: str, model: str) -> tuple[list[str], np.ndarray]:
    """Cached variant of `next_token_logprobs_openai`.

    Results are stored in `CACHE_next_token_logprobs_openai_cached` through the
    `delegating_cached` decorator (presumably keyed by the call arguments —
    see `.caching` to confirm). Returns the same `(tokens, logprobs)` pair as
    the wrapped function.
    """
    return next_token_logprobs_openai(prompt, model)


@functools.lru_cache(maxsize=1)
@beartype
def load_hf_model_memcached(model_name_or_path: str, revision="main") -> tuple:
    """
    Loads a Hugging Face Hub causal language model and its tokenizer.

    Cached so that the last model is kept in memory, and when a different model is
    requested, the old one is dropped from the cache to save GPU RAM.
    (No other reference to the model must exist elsewhere, though.)

    Args:
        model_name_or_path: Hub model id or local path.
        revision: Model revision to load (defaults to "main").

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    # This body only runs on a cache miss. Drop the previously cached model and
    # collect garbage *before* loading, so both models are not held at once.
    load_hf_model_memcached.cache_clear()
    gc.collect()
    print(f"Loading model {model_name_or_path}, revision={revision!r} ...")

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    # The requested revision is honoured for every model (it used to be silently
    # ignored for gpt2*); with the default "main" the gpt2 behaviour is unchanged.
    model_kwargs = {"device_map": "auto", "revision": revision}
    if model_name_or_path.startswith("gpt2"):
        # gpt2 variants are loaded in half precision.
        model_kwargs["torch_dtype"] = torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
    return tokenizer, model


def next_token_logprobs_hf(prompt: str, model: str) -> list[float]:
    """Compute the next-token log-probabilities over all of the model's output scores.

    Args:
        prompt: Text whose next token is to be predicted.
        model: Hugging Face model name or path, optionally suffixed with
            ``"#<revision>"``; without a suffix the ``"main"`` revision is used.

    Returns:
        A list of floats: the log-softmax of the scores the model produced for
        the single generated position.
    """
    import torch

    if "#" in model:
        model, revision = model.split("#")
    else:
        revision = "main"
    tokenizer, hf_model = load_hf_model_memcached(model, revision=revision)
    kwargs = {}
    if "gpt2" in model:
        # Use EOS as the padding token id for gpt2-family models.
        kwargs["pad_token_id"] = tokenizer.eos_token_id
    # TODO: detect model type
    # Put the inputs on the model's own device instead of assuming CUDA, so a
    # model that `device_map="auto"` placed on CPU does not fail here.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model.device)
    # print(f"{model!r} {prompt[:20]!r}...{prompt[-20:]!r}")
    generated = hf_model.generate(
        inputs=input_ids,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        do_sample=False,
        **kwargs,
    )
    # scores[0] belongs to the only generated step; [0] picks the single batch item.
    scores = generated.scores[0][0]
    return torch.nn.functional.log_softmax(scores, dim=0).tolist()


@delegating_cached(lambda: CACHE_next_token_logprobs_hf_cached)
def next_token_logprobs_hf_cached(prompt: str, model: str) -> list[float]:
    """Cached variant of `next_token_logprobs_hf`.

    Results are stored in `CACHE_next_token_logprobs_hf_cached` through the
    `delegating_cached` decorator (presumably keyed by the call arguments —
    see `.caching` to confirm). Returns the same list of log-probabilities as
    the wrapped function.
    """
    return next_token_logprobs_hf(prompt, model)


# The backing cache is sized for trying a handful of temperatures (<10), subprompts (<10) etc.,
# while staying within approx 200 MB of RAM even with 100k-token logprobs.
@delegating_cached(lambda: CACHE_next_token_logprobs)
def next_token_logprobs(
    prompt: str,
    model: str,
    top_n: int = None,
    min_logprob: float = None,
    restrict_to: tuple[str, ...] = None,
) -> TokenLogprobs:
    """Build the `TokenLogprobs` of the next token after `prompt` for `model`.

    OpenAI models (per `is_openai_model`) go through
    `next_token_logprobs_openai_cached`; all other models go through
    `next_token_logprobs_hf_cached` with the token set from
    `TokenSet.from_model(model)` and `ABSENT_TOKEN_LOGPROB` as `others_logprob`.

    Args:
        prompt: Text to continue.
        model: Model name (OpenAI id, or Hugging Face name/path).
        top_n: Passed as `n` to `TokenLogprobs.top_tokens` when filtering.
        min_logprob: Passed as `min_logprob` to `TokenLogprobs.top_tokens`.
        restrict_to: If given, the result is rebuilt over just these tokens.

    Returns:
        The resulting `TokenLogprobs`, filtered through `top_tokens` when
        `top_n` or `min_logprob` is given.
    """
    if is_openai_model(model):
        api_tokens, api_logprobs = next_token_logprobs_openai_cached(prompt, model)
        result = TokenLogprobs(logprobs=np.float32(api_logprobs), tokens=api_tokens, sort=True)
        # Disabled idea: cap the upper bound for unseen tokens by the least likely seen token,
        # i.e. result.others_logprob = min(result.others_logprob, np.min(result.logprobs))
    else:
        full_logprobs = np.float32(next_token_logprobs_hf_cached(prompt, model))
        result = TokenLogprobs(
            logprobs=full_logprobs,
            tokens=TokenSet.from_model(model),
            others_logprob=ABSENT_TOKEN_LOGPROB,
        )

    if restrict_to is not None:
        allowed = TokenSet(restrict_to)
        # `others_logprob` is deliberately not carried over to the restricted result.
        result = TokenLogprobs(allowed, result.logprobs_for_tokenset(allowed))

    wants_filtering = not (top_n is None and min_logprob is None)
    if wants_filtering:
        result = result.top_tokens(n=top_n, min_logprob=min_logprob)
    return result


def next_token_logprobs_cache_info() -> np.ndarray:
    """Add up the `cache_info()` of the HF and the OpenAI logprob caches.

    Each info is converted to an int64 numpy array and the two are summed
    element-wise (presumably [hits, misses, maxsize, currsize] — confirm
    against `.caching`).
    """
    cached_functions = (
        next_token_logprobs_hf_cached,
        next_token_logprobs_openai_cached,
    )
    total = 0
    for cached_fn in cached_functions:
        total = total + np.array(cached_fn.cache_info(), dtype=np.int64)
    return total
